Variational AutoEncoder (VAE)
Adapted from https://github.com/Jackson-Kang/Pytorch-VAE-tutorial
Consider downloading this Jupyter Notebook and running it locally, or test it on Colab.
Install required packages
! pip install torch
! pip install matplotlib
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torchvision.utils import save_image, make_grid
# Model Hyperparameters
dataset_path = '~/datasets'
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 100
x_dim = 784        # 28 x 28 MNIST images, flattened
hidden_dim = 400
latent_dim = 200
lr = 1e-3
epochs = 30
Step 1. Load (or download) Dataset
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])

kwargs = {'num_workers': 1, 'pin_memory': True}

train_dataset = MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, **kwargs)
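Before defining the model, a quick sanity check on the data pipeline can save debugging time later (a minimal sketch, not part of the original tutorial):

# Quick sanity check (a sketch): each batch should be [batch_size, 1, 28, 28]
x, y = next(iter(train_loader))
print(x.shape, y.shape)                # expected: torch.Size([100, 1, 28, 28]) torch.Size([100])
print(x.min().item(), x.max().item())  # ToTensor() scales pixel values into [0, 1]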
Step 2. Define our model: Variational AutoEncoder (VAE)
"""
A simple implementation of Gaussian MLP Encoder and Decoder
"""
class Encoder(nn.Module):
def __init__ (self , input_dim, hidden_dim, latent_dim):
super (Encoder, self ).__init__ ()
self .FC_input = nn.Linear(input_dim, hidden_dim)
self .FC_input2 = nn.Linear(hidden_dim, hidden_dim)
self .FC_mean = nn.Linear(hidden_dim, latent_dim)
self .FC_var = nn.Linear (hidden_dim, latent_dim)
self .LeakyReLU = nn.LeakyReLU(0.2 )
self .training = True
def forward(self , x):
h_ = self .LeakyReLU(self .FC_input(x))
h_ = self .LeakyReLU(self .FC_input2(h_))
mean = self .FC_mean(h_)
log_var = self .FC_var(h_) # encoder produces mean and log of variance
# (i.e., parateters of simple tractable normal distribution "q"
return mean, log_var
class Decoder(nn.Module):
def __init__ (self , latent_dim, hidden_dim, output_dim):
super (Decoder, self ).__init__ ()
self .FC_hidden = nn.Linear(latent_dim, hidden_dim)
self .FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
self .FC_output = nn.Linear(hidden_dim, output_dim)
self .LeakyReLU = nn.LeakyReLU(0.2 )
def forward(self , x):
h = self .LeakyReLU(self .FC_hidden(x))
h = self .LeakyReLU(self .FC_hidden2(h))
x_hat = torch.sigmoid(self .FC_output(h))
return x_hat
class Model(nn.Module):
def __init__ (self , Encoder, Decoder):
super (Model, self ).__init__ ()
self .Encoder = Encoder
self .Decoder = Decoder
def reparameterization(self , mean, var):
epsilon = torch.randn_like(var).to(DEVICE) # sampling epsilon
z = mean + var* epsilon # reparameterization trick
return z
def forward(self , x):
mean, log_var = self .Encoder(x)
z = self .reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
x_hat = self .Decoder(z)
return x_hat, mean, log_var
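The reparameterization trick is what keeps the sampling step differentiable: instead of sampling z ~ N(mean, sigma^2) directly, we sample epsilon ~ N(0, I) and compute z = mean + sigma * epsilon, so the randomness is isolated in epsilon and gradients flow through mean and sigma. A minimal standalone sketch (tensor names are illustrative):

# Sketch: gradients flow through mean and log_var under reparameterization
mean = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)
epsilon = torch.randn(3)                       # the randomness is isolated here
z = mean + torch.exp(0.5 * log_var) * epsilon  # differentiable in mean and log_var
z.sum().backward()
print(mean.grad, log_var.grad)                 # both gradients are populated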
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)
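Optionally, count the trainable parameters to get a feel for the model's size (a small sketch; the exact number depends on the hyperparameters chosen above):

# Sketch: total number of trainable parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")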
Step 3. Define the loss function (reconstruction loss + KL divergence) and the optimizer
from torch.optim import Adam

def loss_function(x, x_hat, mean, log_var):
    # reconstruction term: per-pixel binary cross-entropy, summed over the batch
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    # closed-form KL divergence between N(mean, sigma^2) and the prior N(0, I)
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reproduction_loss + KLD

optimizer = Adam(model.parameters(), lr=lr)
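The KLD term is the closed-form KL divergence between the diagonal Gaussian N(mean, sigma^2) and the standard normal prior N(0, I). As a sanity check (a sketch, not part of the original tutorial), it can be verified against torch.distributions:

# Sketch: verify the closed-form KLD against torch.distributions
from torch.distributions import Normal, kl_divergence

m = torch.randn(4, latent_dim)
lv = torch.randn(4, latent_dim)
kld_closed = -0.5 * torch.sum(1 + lv - m.pow(2) - lv.exp())
kld_dist = kl_divergence(Normal(m, torch.exp(0.5 * lv)), Normal(0.0, 1.0)).sum()
print(torch.allclose(kld_closed, kld_dist, atol=1e-4))  # expected: True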
Step 4. Train Variational AutoEncoder (VAE)
print ("Start training VAE..." )
model.train()
for epoch in range (epochs):
overall_loss = 0
for batch_idx, (x, _) in enumerate (train_loader):
x = x.view(batch_size, x_dim)
x = x.to(DEVICE)
optimizer.zero_grad()
x_hat, mean, log_var = model(x)
loss = loss_function(x, x_hat, mean, log_var)
overall_loss += loss.item()
loss.backward()
optimizer.step()
print (" \t Epoch" , epoch + 1 , "complete!" , " \t Average Loss: " , overall_loss / (batch_idx* batch_size))
print ("Finish!!" )
Start training VAE...
Epoch 1 complete! Average Loss: 175.0465806128704
Epoch 2 complete! Average Loss: 128.90606839850273
Epoch 3 complete! Average Loss: 116.97019198664441
Epoch 4 complete! Average Loss: 112.97430726014711
Epoch 5 complete! Average Loss: 110.6636990198508
Epoch 6 complete! Average Loss: 109.06746030167467
Epoch 7 complete! Average Loss: 107.74997078464106
Epoch 8 complete! Average Loss: 106.72171278302379
Epoch 9 complete! Average Loss: 105.93948053070221
Epoch 10 complete! Average Loss: 105.06694583746348
Epoch 11 complete! Average Loss: 104.39186422357575
Epoch 12 complete! Average Loss: 103.90174390585872
Epoch 13 complete! Average Loss: 103.5466619496035
Epoch 14 complete! Average Loss: 103.15721392750939
Epoch 15 complete! Average Loss: 102.90074143755217
Epoch 16 complete! Average Loss: 102.59622096397642
Epoch 17 complete! Average Loss: 102.38915857483828
Epoch 18 complete! Average Loss: 102.17129713259078
Epoch 19 complete! Average Loss: 101.88609652154632
Epoch 20 complete! Average Loss: 101.79162388159953
Epoch 21 complete! Average Loss: 101.54875074994783
Epoch 22 complete! Average Loss: 101.43628648659224
Epoch 23 complete! Average Loss: 101.28708972962751
Epoch 24 complete! Average Loss: 101.16868274924353
Epoch 25 complete! Average Loss: 100.9885125273894
Epoch 26 complete! Average Loss: 100.94756503351941
Epoch 27 complete! Average Loss: 100.77305422774937
Epoch 28 complete! Average Loss: 100.66967468567404
Epoch 29 complete! Average Loss: 100.5374796372861
Epoch 30 complete! Average Loss: 100.50892499869575
Finish!!
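Before moving on, it is worth persisting the trained weights so the model can be reloaded without retraining; a standard PyTorch pattern (a sketch, the file name is illustrative):

# Sketch: save and reload the trained weights (path is illustrative)
torch.save(model.state_dict(), "vae_mnist.pt")
# later / elsewhere:
model.load_state_dict(torch.load("vae_mnist.pt", map_location=DEVICE))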
Step 5. Reconstruct images from the test dataset
import matplotlib.pyplot as plt

model.eval()

with torch.no_grad():
    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):
        x = x.view(batch_size, x_dim)
        x = x.to(DEVICE)
        x_hat, _, _ = model(x)
        break  # one batch is enough for visualization
Let’s visualize the test samples alongside their reconstructions.
def show_image_grid_detailed(x, nrows=10, ncols=10):
    x = x.view(-1, 28, 28)
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 12))
    for i in range(nrows):
        for j in range(ncols):
            idx = i * ncols + j
            if idx < len(x):
                axes[i, j].imshow(x[idx].cpu().numpy(), cmap='gray')
            axes[i, j].axis('off')
    plt.tight_layout()
    plt.show()

show_image_grid_detailed(x)      # original test samples
show_image_grid_detailed(x_hat)  # their reconstructions
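Beyond eyeballing the grids, reconstruction quality can be quantified with the average loss over the whole test set (a sketch reusing loss_function from Step 3):

# Sketch: average VAE loss over the full test set
model.eval()
total = 0.0
with torch.no_grad():
    for x, _ in test_loader:
        x = x.view(-1, x_dim).to(DEVICE)
        x_hat, mean, log_var = model(x)
        total += loss_function(x, x_hat, mean, log_var).item()
print("Average test loss:", total / len(test_dataset))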
Step 6. Generate image from noise vector
If q(z|x) is close enough to N(0, I) (though never perfectly so, in part due to the posterior collapse problem), we can drop the encoder entirely and sample z directly from the prior N(0, I).
with torch.no_grad():
    noise = torch.randn(batch_size, latent_dim).to(DEVICE)
    generated_images = decoder(noise)

show_image_grid_detailed(generated_images)
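A classic follow-up experiment is to interpolate between two latent codes and decode the path: if the latent space is smooth, the digits should morph gradually. A sketch using the decoder trained above:

# Sketch: decode a linear interpolation between two random latent codes
with torch.no_grad():
    z0 = torch.randn(1, latent_dim).to(DEVICE)
    z1 = torch.randn(1, latent_dim).to(DEVICE)
    alphas = torch.linspace(0, 1, 10, device=DEVICE).view(-1, 1)
    z_path = (1 - alphas) * z0 + alphas * z1   # shape: [10, latent_dim]
    interpolated = decoder(z_path)

show_image_grid_detailed(interpolated, nrows=2, ncols=5)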